{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TF-TRT Keras UNet Segmentation Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we optimize a UNet-style segmentation model from the official Keras examples.\n",
    "\n",
    "You can find the implementation here: https://keras.io/examples/vision/oxford_pets_image_segmentation/\n",
    "\n",
    "Segmentation is a good demonstration for TensorRT because segmentation models tend to be heavy on convolutional layers, which accelerate well. This is especially true of UNet-style models like this one, which are built almost entirely from convolutional layers."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's make sure our GPUs are properly configured and visible:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fri Jan 29 23:22:33 2021       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 450.80.02    Driver Version: 450.80.02    CUDA Version: 11.1     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  Tesla V100-DGXS...  On   | 00000000:07:00.0 Off |                    0 |\n",
      "| N/A   43C    P0    62W / 300W |    125MiB / 16155MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   1  Tesla V100-DGXS...  On   | 00000000:08:00.0 Off |                    0 |\n",
      "| N/A   43C    P0    41W / 300W |      6MiB / 16158MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   2  Tesla V100-DGXS...  On   | 00000000:0E:00.0 Off |                    0 |\n",
      "| N/A   42C    P0    40W / 300W |      6MiB / 16158MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   3  Tesla V100-DGXS...  On   | 00000000:0F:00.0 Off |                    0 |\n",
      "| N/A   43C    P0    39W / 300W |      6MiB / 16158MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remember that to successfully deploy a TensorRT model, you have to make __five key decisions__:\n",
    "\n",
    "1. __What format should I save my model in?__\n",
    "2. __What batch size(s) am I running inference at?__\n",
    "3. __What precision am I running inference at?__\n",
    "4. __What TensorRT tool or integration am I using to convert my model?__\n",
    "5. __What TensorRT runtime am I targeting?__\n",
    "\n",
    "Let's work through each of them for our segmentation model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. What format should I save my model in?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, some setup:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p tmp_savedmodels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we download the U-Net implementation from the Keras examples, pinned to a specific commit so that the file contents stay fixed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-01-29 23:22:36--  https://raw.githubusercontent.com/keras-team/keras-io/cd6201c1bfa37625f503f51e8fd3c572666770e4/examples/vision/oxford_pets_image_segmentation.py\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.40.133\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.40.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 7422 (7.2K) [text/plain]\n",
      "Saving to: ‘oxford_pets_image_segmentation.py’\n",
      "\n",
      "oxford_pets_image_s 100%[===================>]   7.25K  29.9KB/s    in 0.2s    \n",
      "\n",
      "2021-01-29 23:22:37 (29.9 KB/s) - ‘oxford_pets_image_segmentation.py’ saved [7422/7422]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget -O oxford_pets_image_segmentation.py https://raw.githubusercontent.com/keras-team/keras-io/cd6201c1bfa37625f503f51e8fd3c572666770e4/examples/vision/oxford_pets_image_segmentation.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example includes setup lines that we do not need, so we use sed to extract just the model definition. The line ranges below are specific to the pinned version of the file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "!sed -n '70,74 p; 108,170 p' oxford_pets_image_segmentation.py > unet_model.py"
   ]
  },
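  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For readers less familiar with sed, the command above prints only the lines whose numbers fall inside the listed ranges (1-based, inclusive). A minimal Python sketch of the same idea follows. `extract_line_ranges` is a hypothetical helper written for illustration, not part of the Keras example or of this notebook's `helper` module, and the commented-out usage assumes the file downloaded above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_line_ranges(lines, ranges):\n",
    "    \"\"\"Keep lines whose 1-based number falls in any inclusive (start, end) range.\"\"\"\n",
    "    kept = []\n",
    "    for number, line in enumerate(lines, start=1):\n",
    "        for start, end in ranges:\n",
    "            if start <= number <= end:\n",
    "                kept.append(line)\n",
    "    return kept\n",
    "\n",
    "\n",
    "# Small self-contained demo on ten synthetic lines.\n",
    "demo = [f\"line {i}\\n\" for i in range(1, 11)]\n",
    "print(\"\".join(extract_line_ranges(demo, [(2, 3), (7, 8)])), end=\"\")\n",
    "\n",
    "# Assumed usage on the real file (equivalent in spirit to the sed call above):\n",
    "# with open(\"oxford_pets_image_segmentation.py\") as src:\n",
    "#     model_lines = extract_line_ranges(src.readlines(), [(70, 74), (108, 170)])"
   ]
  },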
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we set the input size and number of classes, build the model, and save it in the SavedModel format:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_size = (224, 224)\n",
    "num_classes = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/segment_model/assets\n"
     ]
    }
   ],
   "source": [
    "from unet_model import get_model\n",
    "\n",
    "model = get_model(img_size, num_classes)\n",
    "\n",
    "model_dir = \"tmp_savedmodels/segment_model\"\n",
    "model.save(model_dir) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. What batch size(s) am I running inference at?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will run inference at a batch size of 32, so we create a dummy batch of 32 all-zero RGB images at our input size:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "dummy_input = np.zeros((32, img_size[0], img_size[1], 3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can test our original model, before any optimization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(32, 224, 224, 10)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prediction = model.predict(dummy_input)\n",
    "\n",
    "prediction.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. What precision am I running inference at?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will stick with FP32, the same precision the Keras model uses by default:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "PRECISION = \"FP32\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. What TensorRT tool or integration am I using to convert my model?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use TF-TRT through `ModelOptimizer`, a small example wrapper imported from the accompanying `helper` module:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from helper import ModelOptimizer\n",
    "\n",
    "model_opt = ModelOptimizer(model_dir)"
   ]
  },
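  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The wrapper's internals are not shown in this notebook. As a rough sketch of what a TF-TRT conversion typically involves, the function below calls TensorFlow's `TrtGraphConverterV2` API directly, as available in TensorFlow 2.x builds with TensorRT support. `convert_with_tftrt` is a hypothetical name, and this is an assumption about the general approach, not the actual implementation of `ModelOptimizer`. TensorFlow is imported inside the function, so defining it does not by itself require a TensorRT-enabled build:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_with_tftrt(saved_model_dir, output_dir, precision=\"FP32\"):\n",
    "    \"\"\"Convert a SavedModel with TF-TRT and write the result to output_dir.\"\"\"\n",
    "    # Imported lazily: needs a TensorFlow build with TensorRT support.\n",
    "    from tensorflow.python.compiler.tensorrt import trt_convert as trt\n",
    "\n",
    "    # Start from the default parameters and override only the precision mode.\n",
    "    params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(precision_mode=precision)\n",
    "    converter = trt.TrtGraphConverterV2(\n",
    "        input_saved_model_dir=saved_model_dir,\n",
    "        conversion_params=params,\n",
    "    )\n",
    "    converter.convert()         # replace supported subgraphs with TRTEngineOp nodes\n",
    "    converter.save(output_dir)  # write the converted SavedModel\n",
    "    return output_dir\n",
    "\n",
    "\n",
    "# Assumed usage (not executed here, so the wrapper-based flow below is unaffected):\n",
    "# convert_with_tftrt(model_dir, model_dir + \"_FP32_sketch\", precision=\"FP32\")"
   ]
  },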
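  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert to our target precision, saving the result as a new SavedModel. The `Could not find TRTEngineOp_...` messages in the log are informational: they indicate that the TensorRT engines were not prebuilt and will instead be built and cached at runtime."
   ]
  },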
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Linked TensorRT version: (7, 2, 1)\n",
      "INFO:tensorflow:Loaded TensorRT version: (7, 2, 2)\n",
      "INFO:tensorflow:Loaded TensorRT 7.2.2 and linked TensorFlow against TensorRT 7.2.1. This is supported because TensorRT  minor/patch upgrades are backward compatible\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_1 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_10 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_11 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_2 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_12 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_5 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_3 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_6 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_7 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_4 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_8 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_9 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_0 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/segment_model_FP32/assets\n",
      "conversion complete! prediction shape: (32, 224, 224, 10)\n"
     ]
    }
   ],
   "source": [
    "opt_trt = model_opt.convert(model_dir+'_FP32', precision=PRECISION)\n",
    "print(\"conversion complete! prediction shape:\", opt_trt.predict(dummy_input).shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. What TensorRT runtime am I targeting?"
   ]
  },
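  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will stick with the TF-TRT/TensorFlow runtime. Before timing anything, we run each model once as a warm-up so that one-time setup costs are not included in the measurements:"
   ]
  },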
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warming up...\n",
      "(32, 224, 224, 10)\n",
      "(32, 224, 224, 10)\n",
      "Done warming up!\n"
     ]
    }
   ],
   "source": [
    "print(\"Warming up...\")\n",
    "\n",
    "print(model.predict(dummy_input).shape)\n",
    "print(opt_trt.predict(dummy_input).shape)\n",
    "\n",
    "print(\"Done warming up!\")"
   ]
  },
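  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same warm-up-then-measure pattern can be wrapped in a small helper built on `time.perf_counter`. This is a minimal sketch, not a description of how `%%timeit` is implemented. `benchmark` and its parameters are hypothetical names, and the commented-out usage assumes the `opt_trt` and `dummy_input` objects defined above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import statistics\n",
    "import time\n",
    "\n",
    "\n",
    "def benchmark(fn, warmup=1, repeats=7):\n",
    "    \"\"\"Call fn `warmup` times untimed, then time `repeats` calls (repeats >= 2).\n",
    "\n",
    "    Returns (mean, standard deviation) of the timed calls in milliseconds.\n",
    "    \"\"\"\n",
    "    for _ in range(warmup):\n",
    "        fn()\n",
    "    samples_ms = []\n",
    "    for _ in range(repeats):\n",
    "        start = time.perf_counter()\n",
    "        fn()\n",
    "        samples_ms.append((time.perf_counter() - start) * 1000.0)\n",
    "    return statistics.mean(samples_ms), statistics.stdev(samples_ms)\n",
    "\n",
    "\n",
    "# Self-contained demo on a cheap pure-Python workload.\n",
    "mean_ms, std_ms = benchmark(lambda: sum(range(100_000)))\n",
    "print(f\"{mean_ms:.3f} ms +/- {std_ms:.3f} ms per call\")\n",
    "\n",
    "# Assumed usage with the models in this notebook:\n",
    "# benchmark(lambda: opt_trt.predict(dummy_input))"
   ]
  },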
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performance Comparisons:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "182 ms ± 41.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "\n",
    "preds = model.predict(dummy_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22.3 ms ± 95.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "\n",
    "preds = opt_trt.predict(dummy_input)"
   ]
  },
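  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put the two measurements side by side, the following sketch computes the speedup and throughput from the mean latencies reported above. The numbers are copied by hand from the `%%timeit` outputs of this particular run (Tesla V100, batch size 32), so they will differ on other hardware, and the variable names are illustrative. The baseline measurement has a large standard deviation (± 41.7 ms), so the ratio should be read as approximate:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean latency per batch, copied from the two %%timeit outputs above (milliseconds).\n",
    "baseline_ms = 182.0  # original Keras model\n",
    "tftrt_ms = 22.3      # TF-TRT FP32 model\n",
    "batch_size = 32\n",
    "\n",
    "speedup = baseline_ms / tftrt_ms\n",
    "baseline_ips = batch_size / (baseline_ms / 1000.0)  # images per second\n",
    "tftrt_ips = batch_size / (tftrt_ms / 1000.0)\n",
    "\n",
    "print(f\"speedup: {speedup:.1f}x\")\n",
    "print(f\"throughput: {baseline_ips:.0f} -> {tftrt_ips:.0f} images/s\")"
   ]
  },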
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
